package pers.xdrodger.simple.test.ppt.poi.practice;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.apache.poi.ss.util.CellRangeAddress;
import org.apache.poi.util.Units;
import org.apache.poi.xddf.usermodel.chart.XDDFChartData;
import org.apache.poi.xddf.usermodel.chart.XDDFDataSource;
import org.apache.poi.xddf.usermodel.chart.XDDFDataSourcesFactory;
import org.apache.poi.xddf.usermodel.chart.XDDFNumericalDataSource;
import org.apache.poi.xslf.usermodel.*;
import org.junit.Test;
import pers.xdrodger.simple.test.ppt.poi.FileUtil;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class UpdateDoughnutChart {

    public XSLFChart getChart(XMLSlideShow ppt, String shapeName) throws Exception {
        for (XSLFSlide slide : ppt.getSlides()) {
            for (XSLFShape shape : slide.getShapes()) {
                if (shape.getShapeName().equals(shapeName) && shape instanceof XSLFGraphicFrame && !(shape instanceof XSLFTable)) {
                    System.out.println(shape.getShapeName());
                    XSLFGraphicFrame graphicFrame = (XSLFGraphicFrame) shape;
                    if (!graphicFrame.hasChart()) {
                        continue;
                    }
                    XSLFChart chart = graphicFrame.getChart();
                    return chart;
                }
            }
        }
        return null;
    }

    @Test
    public void updateDoughnutChart() throws Exception {
        FileInputStream fis = new FileInputStream(FileUtil.getInputFilePath() + "ppt-demo.pptx");
        XMLSlideShow ppt = new XMLSlideShow(fis);
        List<ChartData> dataList = getDataList();


        // Category Axis Data
        List<String> deptList = dataList.stream().map(ChartData::getName).collect(Collectors.toList());
        String[] categories = deptList.toArray(new String[0]);
        // Values
        List<Double> eeiList = dataList.stream().map(ChartData::getValue).collect(Collectors.toList());
        List<Double> normList = dataList.stream().map(ChartData::getValue2).collect(Collectors.toList());
        List<Double> companyList = dataList.stream().map(ChartData::getValue3).collect(Collectors.toList());
        Double[] values1 = eeiList.toArray(new Double[0]);
        Double[] values2 = normList.toArray(new Double[0]);
        Double[] values3 = companyList.toArray(new Double[0]);
        XSLFChart chart = getChart(ppt, "update-doughnut-chart");

        final int numOfPoints = categories.length;
        final String categoryDataRange = chart.formatRange(new CellRangeAddress(1, numOfPoints, COLUMN_CATEGORY, COLUMN_CATEGORY));
        final String valuesDataRange = chart.formatRange(new CellRangeAddress(1, numOfPoints, COLUMN_SERIES_1, COLUMN_SERIES_1));
        final String valuesDataRange2 = chart.formatRange(new CellRangeAddress(1, numOfPoints, COLUMN_SERIES_2, COLUMN_SERIES_2));
        final String valuesDataRange3 = chart.formatRange(new CellRangeAddress(1, numOfPoints, COLUMN_SERIES_3, COLUMN_SERIES_3));

        System.out.println(valuesDataRange);
        System.out.println(valuesDataRange2);
        System.out.println(valuesDataRange3);
        final XDDFDataSource<?> categoriesData = XDDFDataSourcesFactory.fromArray(categories, categoryDataRange, COLUMN_CATEGORY);
        final XDDFNumericalDataSource<? extends Number> valuesData = XDDFDataSourcesFactory.fromArray(values1, valuesDataRange, COLUMN_SERIES_1);
        valuesData.setFormatCode("0.00%");
//        values1[6] = 16.0; // if you ever want to change the underlying data, it has to be done before building the data source
        final XDDFNumericalDataSource<? extends Number> valuesData2 = XDDFDataSourcesFactory.fromArray(values2, valuesDataRange2, COLUMN_SERIES_2);
        valuesData2.setFormatCode("0.00%");
        final XDDFNumericalDataSource<? extends Number> valuesData3 = XDDFDataSourcesFactory.fromArray(values3, valuesDataRange3, COLUMN_SERIES_3);
//        valuesData2.setFormatCode("General");
        valuesData3.setFormatCode("0.00%");

        List<ChartSeriesData> chartSeriesDataList = new ArrayList<>();
        ChartSeriesData dimensionScoreSeriesData = new ChartSeriesData();
        dimensionScoreSeriesData.setSeriesName("维度得分");
        dimensionScoreSeriesData.setDataRangeReference(valuesDataRange);
        dimensionScoreSeriesData.setDataList(eeiList);
        dimensionScoreSeriesData.setValueData(valuesData);
        chartSeriesDataList.add(dimensionScoreSeriesData);
        ChartSeriesData normSeriesData = new ChartSeriesData();
        normSeriesData.setSeriesName("全行业平均分");
        normSeriesData.setDataRangeReference(valuesDataRange2);
        normSeriesData.setDataList(normList);
        chartSeriesDataList.add(normSeriesData);
        normSeriesData.setValueData(valuesData2);
        ChartSeriesData companySeriesData = new ChartSeriesData();
        companySeriesData.setSeriesName("公司");
        companySeriesData.setDataRangeReference(valuesDataRange3);
        companySeriesData.setDataList(companyList);
        companySeriesData.setValueData(valuesData3);
        chartSeriesDataList.add(companySeriesData);

        final List<XDDFChartData> chartDataList = chart.getChartSeries();
        for (XDDFChartData chartData : chartDataList) {
            for (int i = 0; i < chartData.getSeriesCount(); i ++) {
                XDDFChartData.Series series = chartData.getSeries(i);
                String dataRangeReference = series.getValuesData().getDataRangeReference();
                ChartSeriesData chartSeriesData = chartSeriesDataList.stream().filter(n -> removeDigit(n.getDataRangeReference()).equals((removeDigit(dataRangeReference)))).findFirst().orElse(null);
                series.replaceData(categoriesData, chartSeriesData.getValueData());
            }
            chart.plot(chartData);
        }

//        chart.setTitleText(chartTitle); // https://stackoverflow.com/questions/30532612
//         chart.setTitleOverlay(overlay);

        // adjust font size for readability
//        bar.getCategoryAxis().getOrAddTextProperties().setFontSize(11.5);
//        chart.getTitle().getOrAddTextProperties().setFontSize(18.2);


        // save the result
        try (OutputStream out = new FileOutputStream(FileUtil.getOutputFilePath() + "update-doughnut-chart.pptx")) {
            ppt.write(out);
        }
    }

    @Test
    public void updateDoughnutChart2() throws Exception {
        FileInputStream fis = new FileInputStream(FileUtil.getInputFilePath() + "ppt-demo.pptx");
        XMLSlideShow ppt = new XMLSlideShow(fis);
        List<ChartData> dataList = getDataList();


        // Category Axis Data
        List<String> deptList = dataList.stream().map(ChartData::getName).collect(Collectors.toList());
        String[] categories = deptList.toArray(new String[0]);
        // Values
        List<Double> eeiList = dataList.stream().map(ChartData::getValue).collect(Collectors.toList());
        Double[] values1 = eeiList.toArray(new Double[0]);
        XSLFChart chart = getChart(ppt, "update-doughnut-chart");

        final int numOfPoints = categories.length;
        final String categoryDataRange = chart.formatRange(new CellRangeAddress(1, numOfPoints, COLUMN_CATEGORY, COLUMN_CATEGORY));
        final String valuesDataRange = chart.formatRange(new CellRangeAddress(1, numOfPoints, COLUMN_SERIES_1, COLUMN_SERIES_1));

        System.out.println(valuesDataRange);
        final XDDFDataSource<?> categoriesData = XDDFDataSourcesFactory.fromArray(categories, categoryDataRange, COLUMN_CATEGORY);
        final XDDFNumericalDataSource<? extends Number> valuesData = XDDFDataSourcesFactory.fromArray(values1, valuesDataRange, COLUMN_SERIES_1);
        valuesData.setFormatCode("General");
//        values1[6] = 16.0; // if you ever want to change the underlying data, it has to be done before building the data source

        List<ChartSeriesData> chartSeriesDataList = new ArrayList<>();
        ChartSeriesData dimensionScoreSeriesData = new ChartSeriesData();
        dimensionScoreSeriesData.setSeriesName("层级");
        dimensionScoreSeriesData.setColumn(1);
        dimensionScoreSeriesData.setDataRangeReference(valuesDataRange);
        dimensionScoreSeriesData.setDataList(eeiList);
        dimensionScoreSeriesData.setValueData(valuesData);
        chartSeriesDataList.add(dimensionScoreSeriesData);


        final List<XDDFChartData> chartDataList = chart.getChartSeries();
        for (XDDFChartData chartData : chartDataList) {
            List<String> dataRangeReferenceList = new ArrayList<>();
            int seriesCount = chartData.getSeriesCount();
            for (int i = 0; i < seriesCount; i ++) {
                XDDFChartData.Series series = chartData.getSeries(i);
                String dataRangeReference = series.getValuesData().getDataRangeReference();
                dataRangeReferenceList.add(dataRangeReference);
            }
            for (int i = seriesCount -1; i >= 0; i --) {
                chartData.removeSeries(i);
            }
            for (String dataRangeReference : dataRangeReferenceList) {
                ChartSeriesData chartSeriesData = chartSeriesDataList.stream().filter(n -> removeDigit(n.getDataRangeReference()).equals((removeDigit(dataRangeReference)))).findFirst().orElse(null);
                XDDFChartData.Series newSeries = chartData.addSeries(categoriesData, chartSeriesData.getValueData());
                newSeries.setTitle(chartSeriesData.getSeriesName(), chart.setSheetTitle(chartSeriesData.getSeriesName(), chartSeriesData.getColumn()));
            }
            chart.plot(chartData);
        }

//        chart.setTitleText(chartTitle); // https://stackoverflow.com/questions/30532612
//         chart.setTitleOverlay(overlay);

        // adjust font size for readability
//        bar.getCategoryAxis().getOrAddTextProperties().setFontSize(11.5);
//        chart.getTitle().getOrAddTextProperties().setFontSize(18.2);


        // save the result
        try (OutputStream out = new FileOutputStream(FileUtil.getOutputFilePath() + "update-doughnut-chart.pptx")) {
            ppt.write(out);
        }
    }

    public String removeDigit(String text) {
        if (StringUtils.isBlank(text)) {
            return "";
        }
        Pattern p = Pattern.compile("[\\d]");
        Matcher matcher = p.matcher(text);
        String result = matcher.replaceAll("");
        return result;
    }

    private static int fromCM(double cm) {
        return (int) (Math.rint(cm * Units.EMU_PER_CENTIMETER));
    }

    private List<ChartData> getDataList() {
        List<ChartData> result = new ArrayList<>();
        result.add(new ChartData("高层2", 16.0, null, null));
        result.add(new ChartData("中层", 46.0, null, null));
        result.add(new ChartData("基层", 589.0, null, null));
        result.add(new ChartData("新增加", 20.0, null, null));
        return result;
    }

    @Data
    private class ChartSeriesData {
        private String seriesName;
        private int column;
        private String dataRangeReference;
        private List<Double> dataList;
        private XDDFNumericalDataSource<? extends Number> valueData;
    }



    @AllArgsConstructor
    @Data
    private class ChartData {
        private String name;
        private Double value;
        private Double value2;
        private Double value3;

    }

    private static final int COLUMN_CATEGORY = 0;
    private static final int COLUMN_SERIES_1 = 1;
    private static final int COLUMN_SERIES_2 = 2;
    private static final int COLUMN_SERIES_3 = 3;
}
